classdef IGEA < ALGORITHM
    % IGEA - Information Gain based Evolutionary Algorithm.
    % Evolutionary algorithm that computes each feature's information gain on
    % the training data, min-max scales it into a weight vector, and passes
    % those weights to its variation operator (NicVariation). After the main
    % loop, every final solution is scored with a 5-nearest-neighbour
    % classifier on the validation set and the results are written to a CSV.

    methods
        function main(Algorithm, Problem)
            % main - Run IGEA on Problem and save the final population to CSV.
            %
            % Inputs:
            %   Algorithm - this IGEA instance; its single parameter (via
            %               ParameterSet, default './') is the directory the
            %               result CSV is written to.
            %   Problem   - problem instance; this method reads TrainIn,
            %               TrainOut, ValidIn, ValidOut, N, str, problem_id
            %               and number from it.
            %
            % Side effect: writes
            %   <class name>-<Problem.str{problem_id}>_<number>.csv
            % into save_path, one row per final solution laid out as
            %   [objective values, validation error rate, decision vector].

            % Parameter: directory in which the result CSV is saved
            save_path = Algorithm.ParameterSet('./');

            % Training data and labels from the problem instance
            train = Problem.TrainIn;
            label = Problem.TrainOut;

            % Information gain per feature (labels appended as the last
            % column; presumably the layout Information_gain expects -
            % TODO confirm), min-max scaled to [0, 1] for use as weights.
            IG = Information_gain([train, label]);
            IG_range = max(IG) - min(IG);
            if IG_range > 0
                W_IG = (IG - min(IG)) ./ IG_range;
            else
                % All features share the same IG (or there is only one):
                % min-max scaling would divide by zero and yield NaN
                % weights, so fall back to uniform weights instead.
                W_IG = ones(size(IG));
            end

            % Initialize population
            Population = InitialPop(Problem, 1);

            % Main loop: IG-weighted variation followed by environmental selection
            while Algorithm.NotTerminated(Population)
                Off = NicVariation(Population, W_IG);
                Population = EnvironmentalSelection([Population, Off], Problem.N);
            end

            % Evaluate every final solution on the validation set
            final = Population.decs;
            nSol = size(final, 1);   % actual population size, not assumed Problem.N
            valid_err = zeros(nSol, 1);
            for idx = 1:nSol
                s = logical(final(idx, :));
                % A KNN model needs at least one feature: draw random subsets
                % until one is non-empty. NOTE: the CSV row still records the
                % original (empty) decision vector for such a solution.
                while ~any(s)
                    s = logical(round(rand(1, size(s, 2))));
                end
                % 5-NN classifier on the selected training features
                mdltest = fitcknn(train(:, s), label, 'NumNeighbors', 5);
                c = predict(mdltest, Problem.ValidIn(:, s));
                % Classification ERROR rate (fraction misclassified). Both
                % sides are flattened so a row/column mismatch cannot
                % broadcast into a matrix comparison.
                valid_err(idx) = mean(c(:) ~= Problem.ValidOut(:));
            end

            % Compile results and write them to a CSV file
            results = [Population.objs, valid_err, final];
            csv_name = strcat(class(Algorithm), '-', Problem.str{Problem.problem_id}, '_', num2str(Problem.number), '.csv');
            writematrix(results, fullfile(save_path, csv_name));
        end
    end
end
